{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4-1,张量的结构操作\n",
    "\n",
    "张量的操作主要包括张量的结构操作和张量的数学运算。\n",
    "\n",
    "张量结构操作诸如：张量创建，索引切片，维度变换，合并分割。\n",
    "\n",
    "张量数学运算主要有：标量运算，向量运算，矩阵运算。另外我们会介绍张量运算的广播机制。\n",
    "\n",
    "本篇我们介绍张量的结构操作。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 一，创建张量"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "张量创建的许多方法和numpy中创建array的方法很像。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 2 3]\n"
     ]
    }
   ],
   "source": [
    "a = tf.constant([1,2,3], dtype = tf.float32)\n",
    "tf.print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1 3 5 7 9]\n",
      "tf.Tensor([1 3 5 7 9], shape=(5,), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "# tf.range(start,limit=None,delta=1,name='range') 返回一个tensor等差数列，\n",
    "# 该tensor中的数值在start到limit之间，不包括limit，delta是等差数列的差值。\n",
    "b = tf.range(1, 10, delta=2)\n",
    "tf.print(b)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0 0.0634343475 0.126868695 ... 6.15313148 6.21656609 6.28] 100\n"
     ]
    }
   ],
   "source": [
    "# tf.linspace(start,stop,num,name=None) 返回一个tensor，\n",
    "# 该tensor中的数值在start到stop区间之间取等差数列（包含start和stop）\n",
    "c = tf.linspace(0.0, 2*3.14, 100)\n",
    "tf.print(c, len(c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0 0 0]\n",
      " [0 0 0]\n",
      " [0 0 0]]\n"
     ]
    }
   ],
   "source": [
    "d = tf.zeros([3, 3])\n",
    "tf.print(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 1 1]\n",
      " [1 1 1]\n",
      " [1 1 1]]\n",
      "[1 1]\n",
      "[[0 0 0]\n",
      " [0 0 0]\n",
      " [0 0 0]]\n"
     ]
    }
   ],
   "source": [
    "a = tf.ones([3, 3]) # 创建全一矩阵\n",
    "b = tf.ones_like([3, 3]) # 新建一个与给定的tensor类型大小一致的tensor，其所有元素为1。\n",
    "# 新建一个与给定的tensor类型大小一致的tensor，其所有元素为0。\n",
    "c = tf.zeros_like(a ,dtype=tf.float32)\n",
    "\n",
    "tf.print(a)\n",
    "tf.print(b)\n",
    "tf.print(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[5 5]\n",
      " [5 5]\n",
      " [5 5]]\n"
     ]
    }
   ],
   "source": [
    "# tf.fill(dim,value,name=None) 创建一个形状大小为dim的tensor，其初始值为value\n",
    "b = tf.fill([3,2], 5)\n",
    "tf.print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1.65130854 9.01481247 6.30974197 4.34546089 2.9193902]\n"
     ]
    }
   ],
   "source": [
    "#均匀分布随机, 元素服从minval和maxval之间的均匀分布。\n",
    "tf.random.set_seed(1.0)\n",
    "a = tf.random.uniform([5], minval=0, maxval=10)\n",
    "tf.print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0.403087884 -1.0880208 -0.0630953535]\n",
      " [1.33655667 0.711760104 -0.489286453]\n",
      " [-0.764221311 -1.03724861 -1.25193381]]\n"
     ]
    }
   ],
   "source": [
    "# 正态分布随机\n",
    "b = tf.random.normal([3, 3], mean=0, stddev=1.0)\n",
    "tf.print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.457012236 -0.406867266 0.728577733 -0.892977774 -0.369404584]\n",
      " [0.323488563 1.19383323 0.888299048 1.25985599 -1.95951891]\n",
      " [-0.202244401 0.294496894 -0.468728036 1.29494202 1.48142183]\n",
      " [0.0810953453 1.63843894 0.556645 0.977199793 -1.17777884]\n",
      " [1.67368948 0.0647980496 -0.705142677 -0.281972528 0.126546144]]\n"
     ]
    }
   ],
   "source": [
    "#正态分布随机，剔除2倍方差以外数据重新生成\n",
    "c = tf.random.truncated_normal((5,5), mean=0, stddev=1.0, dtype=tf.float32)\n",
    "tf.print(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 0 0]\n",
      " [0 1 0]\n",
      " [0 0 1]]\n",
      "\n",
      "[[1 0 0]\n",
      " [0 2 0]\n",
      " [0 0 3]]\n"
     ]
    }
   ],
   "source": [
    "# 特殊矩阵\n",
    "I = tf.eye(3, 3) # 单位矩阵\n",
    "tf.print(I)\n",
    "tf.print(\"\")\n",
    "\n",
    "t = tf.linalg.diag([1, 2, 3]) # 对角阵\n",
    "tf.print(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 二 、索引切片"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "张量的索引切片方式和numpy几乎是一样的。切片时支持缺省参数和省略号。\n",
    "\n",
    "对于tf.Variable,可以通过索引和切片对部分元素进行修改。\n",
    "\n",
    "对于提取张量的连续子区域，也可以使用tf.slice.\n",
    "\n",
    "此外，对于不规则的切片提取,可以使用tf.gather, tf.gather_nd, tf.boolean_mask。\n",
    "\n",
    "tf.boolean_mask功能最为强大，它可以实现tf.gather, tf.gather_nd的功能，并且tf.boolean_mask还可以实现布尔索引。\n",
    "\n",
    "如果要通过修改张量的某些元素得到新的张量，可以使用tf.where, tf.scatter_nd。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[4 7 4 2 9]\n",
      " [9 1 2 4 7]\n",
      " [7 2 7 4 0]\n",
      " [9 6 9 7 2]\n",
      " [3 7 0 0 3]]\n"
     ]
    }
   ],
   "source": [
    "tf.random.set_seed(3)\n",
    "t = tf.random.uniform([5, 5], minval=0, maxval=10, dtype=tf.int32)\n",
    "tf.print(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第0行 t[0]: [4 7 4 2 9]\n",
      "倒数第一行 t[-1]:  [3 7 0 0 3]\n",
      "第1行第3列 t[1,3]: 4\n",
      "第1行第3列 t[1][3]: 4\n"
     ]
    }
   ],
   "source": [
    "tf.print(\"第0行 t[0]:\", t[0]) # 第0行\n",
    "tf.print(\"倒数第一行 t[-1]: \", t[-1]) # 倒数第一行\n",
    "\n",
    "#第1行第3列\n",
    "tf.print(\"第1行第3列 t[1,3]:\",t[1,3])\n",
    "tf.print(\"第1行第3列 t[1][3]:\", t[1][3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第1行至第3行 \n",
      "\n",
      "t[1:4,:] \n",
      " [[9 1 2 4 7]\n",
      " [7 2 7 4 0]\n",
      " [9 6 9 7 2]]\n",
      "\n",
      " tf.slice切片法\n",
      "[[9 1 2 4 7]\n",
      " [7 2 7 4 0]\n",
      " [9 6 9 7 2]]\n"
     ]
    }
   ],
   "source": [
    "#第1行至第3行\n",
    "tf.print(\"第1行至第3行 \\n\")\n",
    "tf.print(\"t[1:4,:]\",'\\n', t[1:4,:])\n",
    "#tf.slice(input,begin_vector,size_vector)\n",
    "tf.print(\"\\n tf.slice切片法\")\n",
    "tf.print(tf.slice(t, [1,0], [3, 5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[9 2 7]\n",
      " [7 7 0]\n",
      " [9 9 2]\n",
      " [3 0 3]]\n"
     ]
    }
   ],
   "source": [
    "#第1行至最后一行，第0列到最后一列每隔两列取一列\n",
    "tf.print(t[1:, ::2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[9 2]\n",
      " [7 7]\n",
      " [9 9]]\n"
     ]
    }
   ],
   "source": [
    "#第1行至倒数第二行，第0列到倒数第二列每隔两列取一列，\n",
    "# 切片终止位置的值取不到\n",
    "tf.print(t[1:-1, :-1:2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1 2]\n",
      " [0 0]]\n"
     ]
    }
   ],
   "source": [
    "#对变量来说，还可以使用索引和切片修改部分元素\n",
    "x = tf.Variable([[1, 2], [3, 4]], dtype=tf.float32)\n",
    "x[1, :].assign(tf.constant([0.0, 0.0]))\n",
    "tf.print(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[7 3 9]\n",
      "  [9 0 7]\n",
      "  [9 6 7]]\n",
      "\n",
      " [[1 3 3]\n",
      "  [0 8 1]\n",
      "  [3 1 0]]\n",
      "\n",
      " [[4 0 6]\n",
      "  [6 2 2]\n",
      "  [7 9 5]]]\n"
     ]
    }
   ],
   "source": [
    "a = tf.random.uniform([3,3,3],minval=0,maxval=10,dtype=tf.int32)\n",
    "tf.print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[3 0 6]\n",
      " [3 8 1]\n",
      " [0 2 9]]\n"
     ]
    }
   ],
   "source": [
    "#省略号可以表示多个冒号\n",
    "tf.print(a[...,1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上切片方式相对规则，对于不规则的切片提取,可以使用tf.gather, tf.gather_nd, tf.boolean_mask。\n",
    "\n",
    "* tf.gather: 类似于数组的索引，可以把向量中某些索引值提取出来，得到新的向量，适用于要提取的索引为不连续的情况。根据indices从params的指定轴axis索引元素(类似于仅能在指定轴进行一维索引).\n",
    "\n",
    "* tf.gather_nd:将params索引为indices指定形状的切片数组中(indices代表索引后的数组形状) indices将切片定义为params的前N个维度，其中N = indices.shape [-1]\n",
    "    \n",
    "* tf.greater 判断函数。首先张量x和张量y的尺寸要相同，输出的tf.greater(x, y)也是一个和x，y尺寸相同的张量。如果x的某个元素比y中对应位置的元素大，则tf.greater(x, y)对应位置返回True，否则返回False。与此类似的函数还有tf.greater_equal。\n",
    "\n",
    "考虑班级成绩册的例子，有4个班级，每个班级10个学生，每个学生7门科目成绩。可以用一个4*10*7的张量来表示。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[52 82 66 ... 17 86 14]\n",
      "  [8 36 94 ... 13 78 41]\n",
      "  [77 53 51 ... 22 91 56]\n",
      "  ...\n",
      "  [11 19 26 ... 89 86 68]\n",
      "  [60 72 0 ... 11 26 15]\n",
      "  [24 99 38 ... 97 44 74]]\n",
      "\n",
      " [[79 73 73 ... 35 3 81]\n",
      "  [83 36 31 ... 75 38 85]\n",
      "  [54 26 67 ... 60 68 98]\n",
      "  ...\n",
      "  [20 5 18 ... 32 45 3]\n",
      "  [72 52 81 ... 88 41 20]\n",
      "  [0 21 89 ... 53 10 90]]\n",
      "\n",
      " [[52 80 22 ... 29 25 60]\n",
      "  [78 71 54 ... 43 98 81]\n",
      "  [21 66 53 ... 97 75 77]\n",
      "  ...\n",
      "  [6 74 3 ... 53 65 43]\n",
      "  [98 36 72 ... 33 36 81]\n",
      "  [61 78 70 ... 7 59 21]]\n",
      "\n",
      " [[56 57 45 ... 23 15 3]\n",
      "  [35 8 82 ... 11 59 97]\n",
      "  [44 6 99 ... 81 60 27]\n",
      "  ...\n",
      "  [76 26 35 ... 51 8 17]\n",
      "  [33 52 53 ... 78 37 31]\n",
      "  [71 27 44 ... 0 52 16]]]\n"
     ]
    }
   ],
   "source": [
    "scores = tf.random.uniform((4,10,7),minval=0,maxval=100,dtype=tf.int32)\n",
    "tf.print(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[52 82 66 ... 17 86 14]\n",
      "  [24 80 70 ... 72 63 96]\n",
      "  [24 99 38 ... 97 44 74]]\n",
      "\n",
      " [[79 73 73 ... 35 3 81]\n",
      "  [46 10 94 ... 23 18 92]\n",
      "  [0 21 89 ... 53 10 90]]\n",
      "\n",
      " [[52 80 22 ... 29 25 60]\n",
      "  [19 12 23 ... 87 86 25]\n",
      "  [61 78 70 ... 7 59 21]]\n",
      "\n",
      " [[56 57 45 ... 23 15 3]\n",
      "  [6 41 79 ... 97 43 13]\n",
      "  [71 27 44 ... 0 52 16]]]\n"
     ]
    }
   ],
   "source": [
    "# 抽取每个班级第0个学生，第5个学生，第9个学生的全部成绩\n",
    "p = tf.gather(scores, [0,5,9], axis=1)\n",
    "tf.print(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[82 55 14]\n",
      "  [80 46 96]\n",
      "  [99 58 74]]\n",
      "\n",
      " [[73 48 81]\n",
      "  [10 38 92]\n",
      "  [21 86 90]]\n",
      "\n",
      " [[80 57 60]\n",
      "  [12 34 25]\n",
      "  [78 71 21]]\n",
      "\n",
      " [[57 75 3]\n",
      "  [41 47 13]\n",
      "  [27 96 16]]]\n"
     ]
    }
   ],
   "source": [
    "#抽取每个班级第0个学生，第5个学生，第9个学生的第1门课程，第3门课程，第6门课程成绩\n",
    "q = tf.gather(tf.gather(scores,[0,5,9],axis=1),[1,3,6],axis=2)\n",
    "tf.print(q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[52 82 66 ... 17 86 14]\n",
      " [99 94 46 ... 1 63 41]\n",
      " [46 83 70 ... 90 85 17]]\n"
     ]
    }
   ],
   "source": [
    "# 抽取第0个班级第0个学生，第2个班级的第4个学生，第3个班级的第6个学生的全部成绩\n",
    "# indices的长度为采样样本的个数，每个元素为采样位置的坐标\n",
    "s = tf.gather_nd(scores,indices = [(0,0),(2,4),(3,6)])\n",
    "tf.print(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[52 82 66 ... 17 86 14]\n",
      "  [24 80 70 ... 72 63 96]\n",
      "  [24 99 38 ... 97 44 74]]\n",
      "\n",
      " [[79 73 73 ... 35 3 81]\n",
      "  [46 10 94 ... 23 18 92]\n",
      "  [0 21 89 ... 53 10 90]]\n",
      "\n",
      " [[52 80 22 ... 29 25 60]\n",
      "  [19 12 23 ... 87 86 25]\n",
      "  [61 78 70 ... 7 59 21]]\n",
      "\n",
      " [[56 57 45 ... 23 15 3]\n",
      "  [6 41 79 ... 97 43 13]\n",
      "  [71 27 44 ... 0 52 16]]]\n"
     ]
    }
   ],
   "source": [
    "#抽取每个班级第0个学生，第5个学生，第9个学生的全部成绩\n",
    "p = tf.boolean_mask(scores,[True,False,False,False,False,\n",
    "                            True,False,False,False,True],axis=1)\n",
    "tf.print(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[52 82 66 ... 17 86 14]\n",
      " [99 94 46 ... 1 63 41]\n",
      " [46 83 70 ... 90 85 17]]\n"
     ]
    }
   ],
   "source": [
    "#抽取第0个班级第0个学生，第2个班级的第4个学生，第3个班级的第6个学生的全部成绩\n",
    "s = tf.boolean_mask(scores,\n",
    "    [[True,False,False,False,False,False,False,False,False,False],\n",
    "     [False,False,False,False,False,False,False,False,False,False],\n",
    "     [False,False,False,False,True,False,False,False,False,False],\n",
    "     [False,False,False,False,False,False,True,False,False,False]])\n",
    "tf.print(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-1 1 -1]\n",
      " [2 2 -2]\n",
      " [3 -3 3]] \n",
      "\n",
      "[-1 -1 -2 -3] \n",
      "\n",
      "[-1 -1 -2 -3]\n"
     ]
    }
   ],
   "source": [
    "#利用tf.boolean_mask可以实现布尔索引\n",
    "\n",
    "#找到矩阵中小于0的元素\n",
    "c = tf.constant([[-1,1,-1],[2,2,-2],[3,-3,3]],dtype=tf.float32)\n",
    "tf.print(c,\"\\n\")\n",
    "\n",
    "tf.print(tf.boolean_mask(c,c<0),\"\\n\") \n",
    "tf.print(c[c<0]) #布尔索引，为boolean_mask的语法糖形式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "以上这些方法仅能提取张量的部分元素值，但不能更改张量的部分元素值得到新的张量。\n",
    "\n",
    "如果要通过修改张量的部分元素值得到新的张量，可以使用tf.where和tf.scatter_nd。\n",
    "\n",
    "* tf.where可以理解为if的张量版本，此外它还可以用于找到满足条件的所有元素的位置坐标。\n",
    "\n",
    "* tf.scatter_nd的作用和tf.gather_nd有些相反，tf.gather_nd用于收集张量的给定位置的元素，tf.scatter_nd可以将某些值插入到一个给定shape的全0的张量的指定位置处。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(4, 2), dtype=int64, numpy=\n",
       "array([[0, 0],\n",
       "       [0, 2],\n",
       "       [1, 2],\n",
       "       [2, 1]])>"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#如果where只有一个参数，将返回所有满足条件的位置坐标\n",
    "indices = tf.where(c<0)\n",
    "indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[0 1 -1]\n",
      " [2 2 -2]\n",
      " [3 0 3]]\n"
     ]
    }
   ],
   "source": [
    "#将张量的第[0,0]和[2,1]两个位置元素替换为0得到新的张量\n",
    "d = c - tf.scatter_nd([[0,0],[2,1]],[c[0,0],c[2,1]],c.shape)\n",
    "tf.print(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 3), dtype=float32, numpy=\n",
       "array([[-1.,  0., -1.],\n",
       "       [ 0.,  0., -2.],\n",
       "       [ 0., -3.,  0.]], dtype=float32)>"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#scatter_nd的作用和gather_nd有些相反\n",
    "#可以将某些值插入到一个给定shape的全0的张量的指定位置处。\n",
    "indices = tf.where(c<0)\n",
    "tf.scatter_nd(indices,tf.gather_nd(c,indices),c.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 三、维度变换\n",
    "\n",
    "维度变换相关函数主要有 tf.reshape, tf.squeeze, tf.expand_dims, tf.transpose.\n",
    "\n",
    "* tf.reshape 可以改变张量的形状。\n",
    "\n",
    "* tf.squeeze 可以减少维度。\n",
    "\n",
    "* tf.expand_dims 可以增加维度。\n",
    "\n",
    "* tf.transpose 可以交换维度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorShape([1, 3, 3, 2])\n",
      "[[[[100 44]\n",
      "   [181 14]\n",
      "   [90 53]]\n",
      "\n",
      "  [[205 141]\n",
      "   [14 24]\n",
      "   [239 46]]\n",
      "\n",
      "  [[225 174]\n",
      "   [212 78]\n",
      "   [14 144]]]]\n"
     ]
    }
   ],
   "source": [
    "a = tf.random.uniform(shape=[1,3,3,2],\n",
    "                      minval=0,maxval=255,dtype=tf.int32)\n",
    "tf.print(a.shape)\n",
    "tf.print(a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorShape([3, 6])\n",
      "[[100 44 181 14 90 53]\n",
      " [205 141 14 24 239 46]\n",
      " [225 174 212 78 14 144]]\n"
     ]
    }
   ],
   "source": [
    "# 改成 （3,6）形状的张量\n",
    "b = tf.reshape(a, [3, 6])\n",
    "tf.print(b.shape)\n",
    "tf.print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[100 44]\n",
      "   [181 14]\n",
      "   [90 53]]\n",
      "\n",
      "  [[205 141]\n",
      "   [14 24]\n",
      "   [239 46]]\n",
      "\n",
      "  [[225 174]\n",
      "   [212 78]\n",
      "   [14 144]]]]\n"
     ]
    }
   ],
   "source": [
    "# 改回成 [1,3,3,2] 形状的张量\n",
    "c = tf.reshape(b,[1,3,3,2])\n",
    "tf.print(c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如果张量在某个维度上只有一个元素，利用tf.squeeze可以消除这个维度。和tf.reshape相似，它本质上不会改变张量元素的存储顺序。\n",
    "\n",
    "张量的各个元素在内存中是线性存储的，其一般规律是，同一层级中的相邻元素的物理地址也相邻。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorShape([3, 3, 2])\n",
      "[[[100 44]\n",
      "  [181 14]\n",
      "  [90 53]]\n",
      "\n",
      " [[205 141]\n",
      "  [14 24]\n",
      "  [239 46]]\n",
      "\n",
      " [[225 174]\n",
      "  [212 78]\n",
      "  [14 144]]]\n"
     ]
    }
   ],
   "source": [
    "s = tf.squeeze(a)\n",
    "tf.print(s.shape)\n",
    "tf.print(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[100 44]\n",
      "   [181 14]\n",
      "   [90 53]]\n",
      "\n",
      "  [[205 141]\n",
      "   [14 24]\n",
      "   [239 46]]\n",
      "\n",
      "  [[225 174]\n",
      "   [212 78]\n",
      "   [14 144]]]]\n",
      "TensorShape([1, 3, 3, 2])\n"
     ]
    }
   ],
   "source": [
    "d = tf.expand_dims(s, axis=0) #在第0维插入长度为1的一个维度\n",
    "tf.print(d)\n",
    "tf.print(d.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[[100 44]]\n",
      "\n",
      "  [[181 14]]\n",
      "\n",
      "  [[90 53]]]\n",
      "\n",
      "\n",
      " [[[205 141]]\n",
      "\n",
      "  [[14 24]]\n",
      "\n",
      "  [[239 46]]]\n",
      "\n",
      "\n",
      " [[[225 174]]\n",
      "\n",
      "  [[212 78]]\n",
      "\n",
      "  [[14 144]]]]\n",
      "TensorShape([3, 3, 1, 2])\n"
     ]
    }
   ],
   "source": [
    "d = tf.expand_dims(s, axis=2) #在第0维插入长度为1的一个维度\n",
    "tf.print(d)\n",
    "tf.print(d.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "tf.transpose可以交换张量的维度，与tf.reshape不同，它会改变张量元素的存储顺序。\n",
    "tf.transpose常用于图片存储格式的变换上。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorShape([100, 600, 600, 4])\n",
      "TensorShape([4, 600, 600, 100])\n"
     ]
    }
   ],
   "source": [
    "# Batch,Height,Width,Channel\n",
    "a = tf.random.uniform(shape=[100,600,600,4],minval=0,maxval=255,dtype=tf.int32)\n",
    "tf.print(a.shape)\n",
    "\n",
    "# 转换成 Channel,Height,Width,Batch\n",
    "s= tf.transpose(a, perm=[3,1,2,0])\n",
    "tf.print(s.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 四、合并分割\n",
    "\n",
    "和numpy类似，可以用：\n",
    "* tf.concat和tf.stack方法对多个张量进行合并；\n",
    "* tf.split方法把一个张量分割成多个张量。\n",
    "\n",
    "tf.concat和tf.stack有略微的区别：\n",
    "* tf.concat是连接，不会增加维度；\n",
    "* tf.stack是堆叠，会增加维度。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(6, 2), dtype=float32, numpy=\n",
       "array([[ 1.,  2.],\n",
       "       [ 3.,  4.],\n",
       "       [ 5.,  6.],\n",
       "       [ 7.,  8.],\n",
       "       [ 9., 10.],\n",
       "       [11., 12.]], dtype=float32)>"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a = tf.constant([[1.0,2.0],[3.0,4.0]])\n",
    "b = tf.constant([[5.0,6.0],[7.0,8.0]])\n",
    "c = tf.constant([[9.0,10.0],[11.0,12.0]])\n",
    "\n",
    "tf.concat([a,b,c],axis = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 6), dtype=float32, numpy=\n",
       "array([[ 1.,  2.,  5.,  6.,  9., 10.],\n",
       "       [ 3.,  4.,  7.,  8., 11., 12.]], dtype=float32)>"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.concat([a,b,c], axis = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorShape([6, 2])\n",
      "TensorShape([2, 6])\n"
     ]
    }
   ],
   "source": [
    "tf.print((tf.concat([a,b,c],axis = 0)).shape)\n",
    "tf.print((tf.concat([a,b,c],axis = 1)).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[1 2]\n",
      "  [3 4]]\n",
      "\n",
      " [[5 6]\n",
      "  [7 8]]\n",
      "\n",
      " [[9 10]\n",
      "  [11 12]]]\n",
      "TensorShape([3, 2, 2])\n"
     ]
    }
   ],
   "source": [
    "t = tf.stack([a,b,c])\n",
    "tf.print(t)\n",
    "tf.print(t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[[1 2]\n",
      "  [5 6]\n",
      "  [9 10]]\n",
      "\n",
      " [[3 4]\n",
      "  [7 8]\n",
      "  [11 12]]]\n",
      "TensorShape([2, 3, 2])\n"
     ]
    }
   ],
   "source": [
    "t = tf.stack([a,b,c], axis=1)\n",
    "tf.print(t)\n",
    "tf.print(t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = tf.constant([[1.0,2.0],[3.0,4.0]])\n",
    "b = tf.constant([[5.0,6.0],[7.0,8.0]])\n",
    "c = tf.constant([[9.0,10.0],[11.0,12.0]])\n",
    "\n",
    "c = tf.concat([a,b,c],axis = 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "tf.split是tf.concat的逆运算，可以指定分割份数平均分割，也可以通过指定每份的记录数量进行分割。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       " array([[1., 2.],\n",
       "        [3., 4.]], dtype=float32)>,\n",
       " <tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       " array([[5., 6.],\n",
       "        [7., 8.]], dtype=float32)>,\n",
       " <tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       " array([[ 9., 10.],\n",
       "        [11., 12.]], dtype=float32)>]"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#tf.split(value, num_or_size_splits, axis)\n",
    "tf.split(c, 3, axis=0)  #指定分割份数，平均分割"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       " array([[1., 2.],\n",
       "        [3., 4.]], dtype=float32)>,\n",
       " <tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       " array([[5., 6.],\n",
       "        [7., 8.]], dtype=float32)>,\n",
       " <tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       " array([[ 9., 10.],\n",
       "        [11., 12.]], dtype=float32)>]"
      ]
     },
     "execution_count": 82,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.split(c,[2,2,2],axis = 0) #指定每份的记录数量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
